#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

#include "test_bazel/test_cuda/bank-conflict/matmul.hpp"
#include "test_bazel/test_cuda/bank-conflict/timer.hpp"
#include "test_bazel/test_cuda/bank-conflict/utils.hpp"

// File-scope seed value; main() assigns it and passes it to initMatrix()
// (1 for the first input matrix, 2 for the second).
int seed;
/*
 * Benchmark driver for the matmul variants declared in matmul.hpp.
 *
 * Flow:
 *   1. Allocate and initialise two width x width input matrices on the host.
 *   2. Run MatmulOnDevice once as a warm-up; its output (h_matP) is kept as
 *      the reference result for the rest of the program.
 *   3. Run every variant into d_matP, print its GPU time and compare the
 *      output against the reference with compareMat().
 *
 * Returns 0 on completion, 1 if a host buffer could not be allocated.
 */
int main() {
  Timer timer;

  int width = 1 << 12; // 4,096
  int low = 0;
  int high = 1;
  int size = width * width;
  int blockSize = 16;
  // Forwarded as the last argument of the shared-memory variants; judging by
  // the labels below it selects static vs. dynamic shared memory
  // (NOTE(review): confirm against matmul.hpp).
  bool statMem = true;

  // All four buffers live in host memory. Despite its "d_" prefix, d_matP is
  // a plain malloc'ed buffer that receives the result of the variant under
  // test; h_matP holds the reference result.
  float *h_matM = (float *)malloc(size * sizeof(float));
  float *h_matN = (float *)malloc(size * sizeof(float));
  float *h_matP = (float *)malloc(size * sizeof(float));
  float *d_matP = (float *)malloc(size * sizeof(float));
  if (h_matM == NULL || h_matN == NULL || h_matP == NULL || d_matP == NULL) {
    fprintf(stderr, "failed to allocate host matrices (%d elements each)\n",
            size);
    // free(NULL) is a no-op, so releasing all four is safe here.
    free(h_matM);
    free(h_matN);
    free(h_matP);
    free(d_matP);
    return 1;
  }

  // Fixed seeds keep runs reproducible; the two inputs use different seeds.
  // seed = (unsigned)time(NULL);
  seed = 1;
  initMatrix(h_matM, size, low, high, seed);
  seed += 1;
  initMatrix(h_matN, size, low, high, seed);

  LOG("Input size is %d x %d", width, width);

  /* GPU warmup; the result in h_matP doubles as the reference output */
  timer.start_gpu();
  MatmulOnDevice(h_matM, h_matN, h_matP, width, blockSize);
  timer.stop_gpu();
  timer.duration_gpu("matmul in gpu(warmup)");

  /* GPU general implementation <<<256, 16>>> */
  timer.start_gpu();
  MatmulOnDevice(h_matM, h_matN, d_matP, width, blockSize);
  timer.stop_gpu();
  timer.duration_gpu("matmul in gpu(general)");
  compareMat(h_matP, d_matP, size);

  /* GPU shared-memory implementation, statMem = true */
  // Writes into d_matP (not h_matP) so the reference result stays intact for
  // every later comparison.
  timer.start_gpu();
  MatmulSharedOnDevice(h_matM, h_matN, d_matP, width, blockSize, statMem);
  timer.stop_gpu();
  timer.duration_gpu("matmul in gpu(shared memory(static))");
  compareMat(h_matP, d_matP, size);

  /* GPU shared-memory implementation with bank conflicts, statMem = true */
  timer.start_gpu();
  MatmulSharedConflictOnDevice(h_matM, h_matN, d_matP, width, blockSize,
                               statMem);
  timer.stop_gpu();
  timer.duration_gpu("matmul in gpu(shared memory(static, bank conf))");
  compareMat(h_matP, d_matP, size);

  /* GPU shared-memory implementation with padding, statMem = true */
  timer.start_gpu();
  MatmulSharedConflictPadOnDevice(h_matM, h_matN, d_matP, width, blockSize,
                                  statMem);
  timer.stop_gpu();
  timer.duration_gpu(
      "matmul in gpu(shared memory(static, pad resolve bank conf))");
  compareMat(h_matP, d_matP, size);

  // The remaining three runs repeat the shared-memory variants with
  // statMem = false.
  statMem = false;

  /* GPU shared-memory implementation, statMem = false */
  timer.start_gpu();
  MatmulSharedOnDevice(h_matM, h_matN, d_matP, width, blockSize, statMem);
  timer.stop_gpu();
  timer.duration_gpu("matmul in gpu(shared memory(dynamic))");
  compareMat(h_matP, d_matP, size);

  /* GPU shared-memory implementation with bank conflicts, statMem = false */
  timer.start_gpu();
  MatmulSharedConflictOnDevice(h_matM, h_matN, d_matP, width, blockSize,
                               statMem);
  timer.stop_gpu();
  timer.duration_gpu("matmul in gpu(shared memory(dynamic, bank conf)");
  compareMat(h_matP, d_matP, size);

  /* GPU shared-memory implementation with padding, statMem = false */
  timer.start_gpu();
  MatmulSharedConflictPadOnDevice(h_matM, h_matN, d_matP, width, blockSize,
                                  statMem);
  timer.stop_gpu();
  timer.duration_gpu(
      "matmul in gpu(shared memory(dynamic, pad resolve bank conf))");
  compareMat(h_matP, d_matP, size);

  // Release the host buffers.
  free(h_matM);
  free(h_matN);
  free(h_matP);
  free(d_matP);

  return 0;
}
